{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "d8013313",
   "metadata": {},
   "source": [
    "## GReaT Example with California Housing Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb5ba9dc-4372-45ff-80bb-e63fa99e1aa5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute only once! (The guard below turns accidental re-runs into a no-op.)\n",
    "import os\n",
    "import sys\n",
    "\n",
    "# Unguarded, every re-run would climb one more directory level and\n",
    "# append another \"..\" entry to sys.path.\n",
    "if not globals().get(\"_PARENT_DIR_ACTIVATED\", False):\n",
    "    sys.path.append(\"..\")\n",
    "    os.chdir(\"..\")\n",
    "    _PARENT_DIR_ACTIVATED = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce19c02-2b33-4f76-bdb5-01889292f6a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension; mode 2 re-imports modules before each\n",
    "# cell runs, so edits to the local package are picked up without a restart.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "058f3bff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from sklearn import datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f78e5ef1-1180-4e81-878c-f2e248226795",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import set_logging_level\n",
    "from be_great import GReaT"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd831932",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78fcfe3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Request INFO-level logging through the project's `set_logging_level` helper\n",
    "# and keep whatever it returns under the name `logger`.\n",
    "logger = set_logging_level(logging.INFO)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "95782c58",
   "metadata": {},
   "source": [
    "### Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fe4f469",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the California housing data in pandas form (as_frame=True) and keep\n",
    "# the `frame` attribute of the returned object as the working table.\n",
    "housing = datasets.fetch_california_housing(as_frame=True)\n",
    "data = housing.frame\n",
    "data.head()  # preview the first rows"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39ff376c",
   "metadata": {},
   "source": [
    "### Create GReaT Model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e10bf87c-457d-4f9d-a26a-12bdbb4cb9b1",
   "metadata": {},
   "source": [
    "Only one epoch is trained here, for demonstration purposes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee720651",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the GReaT model; the two commented-out keyword arguments are optional overrides.\n",
    "great = GReaT(\"distilgpt2\",                         # Name of the large language model used (see HuggingFace for more options)\n",
    "              epochs=1,                             # Number of epochs to train (only one epoch for demonstration)\n",
    "              save_steps=2000,                      # Save model weights every x steps\n",
    "              logging_steps=50,                     # Log the loss and learning rate every x steps\n",
    "              experiment_dir=\"trainer_california\",  # Name of the directory where all intermediate steps are saved\n",
    "              #lr_scheduler_type=\"constant\",        # Specify the learning rate scheduler\n",
    "              #learning_rate=5e-5                   # Set the initial learning rate\n",
    "             )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b19304d",
   "metadata": {},
   "source": [
    "### Start Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7173938",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Train on the housing frame; the returned trainer's `state.log_history`\n",
    "# is read in the following cells to plot the loss.\n",
    "trainer = great.fit(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0554c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shallow-copy the log so the trainer's own history list is left intact.\n",
    "loss_hist = list(trainer.state.log_history)\n",
    "# Remove the last record (presumably the end-of-training summary, which has no\n",
    "# \"loss\" entry - TODO confirm); the removed record is shown as the cell output.\n",
    "loss_hist.pop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ccc84e57",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the records that actually carry a \"loss\" value, so records\n",
    "# without that key (e.g. a summary or evaluation record) cannot raise KeyError\n",
    "# and the two lists always stay aligned.\n",
    "loss_records = [x for x in loss_hist if \"loss\" in x]\n",
    "loss = [x[\"loss\"] for x in loss_records]\n",
    "epochs = [x[\"epoch\"] for x in loss_records]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc37dd63",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Logged loss against the epoch value recorded with it.\n",
    "plt.plot(epochs, loss)\n",
    "plt.xlabel(\"Epoch\")\n",
    "plt.ylabel(\"Loss\")\n",
    "plt.title(\"Training loss\");  # trailing ';' hides the text repr of the last call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "744176a5",
   "metadata": {},
   "source": [
    "### Save Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7305e71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the model under the path \"california\" (the same name the\n",
    "# commented-out `GReaT.load_from_dir` call in the load section uses).\n",
    "great.save(\"california\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "396c6731",
   "metadata": {},
   "source": [
    "### Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa91661d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to restore a model previously written by great.save(\"california\").\n",
    "# great = GReaT.load_from_dir(\"california\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1612a7b2-a060-407c-80d8-bab46a96686a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optionally swap in externally fine-tuned weights. The checkpoint lives outside\n",
    "# the working directory, so skip it with a message when it is absent instead of\n",
    "# aborting the whole notebook run.\n",
    "finetuned_weights = \"../great_private/models/california/california_distilgpt2_100.pt\"\n",
    "if os.path.exists(finetuned_weights):\n",
    "    great.load_finetuned_model(finetuned_weights)\n",
    "else:\n",
    "    print(f\"Checkpoint {finetuned_weights!r} not found; keeping the model trained above.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5a9ca066",
   "metadata": {},
   "source": [
    "### Generate Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b62cfebf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of synthetic rows requested from the model in the sampling cell.\n",
    "n_samples = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95b3f31d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep \"cuda:1\" (the second GPU) when it exists; otherwise fall back to the\n",
    "# first GPU or the CPU so the cell also runs on single-GPU and CPU-only machines.\n",
    "if torch.cuda.device_count() > 1:\n",
    "    sampling_device = \"cuda:1\"\n",
    "elif torch.cuda.is_available():\n",
    "    sampling_device = \"cuda\"\n",
    "else:\n",
    "    sampling_device = \"cpu\"\n",
    "samples = great.sample(n_samples, k=50, device=sampling_device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dff8c01d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first generated rows.\n",
    "samples.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db334053-1166-409c-acb4-7bdf2243a836",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions of the generated table, to compare against the requested n_samples.\n",
    "samples.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc0bfd6c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the generated samples to disk. NOTE(review): assuming `samples` is a\n",
    "# pandas DataFrame, the default index=True stores the row index as an unnamed\n",
    "# first column, so a plain pd.read_csv() of this file returns one extra column.\n",
    "samples.to_csv(\"california_samples.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9e2cd15",
   "metadata": {},
   "source": [
    "## Plot Data"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8c529095",
   "metadata": {},
   "source": [
    "Original Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ef00f79-cd0f-4e27-b11b-61f46bc37949",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed seed so the same reference subset (and therefore the same plot) is\n",
    "# drawn on every run of the notebook.\n",
    "true_samples = data.sample(n=1000, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d086807b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Geographic scatter of the real rows, coloured by the MedHouseVal column.\n",
    "plt.scatter(true_samples[\"Longitude\"], true_samples[\"Latitude\"], c=true_samples[\"MedHouseVal\"])\n",
    "plt.colorbar(label=\"MedHouseVal\")  # colour scale for the scatter drawn just above\n",
    "plt.xlabel(\"Longitude\")\n",
    "plt.ylabel(\"Latitude\")\n",
    "plt.title(\"Original data\");  # trailing ';' hides the text repr of the last call"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20d014af",
   "metadata": {},
   "source": [
    "Generated samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc661470",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to reuse the samples written by the to_csv cell instead of sampling again.\n",
    "# index_col=0 reads back the row index that to_csv stored as the first column.\n",
    "#samples = pd.read_csv(\"california_samples.csv\", index_col=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "54b5652d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Geographic scatter of the generated rows, coloured by the MedHouseVal column.\n",
    "plt.scatter(samples[\"Longitude\"], samples[\"Latitude\"], c=samples[\"MedHouseVal\"])\n",
    "plt.colorbar(label=\"MedHouseVal\")  # colour scale for the scatter drawn just above\n",
    "plt.xlabel(\"Longitude\")\n",
    "plt.ylabel(\"Latitude\")\n",
    "plt.title(\"Generated samples\");  # trailing ';' hides the text repr of the last call"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "967789ca",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e86e1320",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03ca2c0e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
